package cool.mtc.web.result;

import cool.mtc.core.exception.CustomException;
import cool.mtc.core.exception.ParamException;
import cool.mtc.core.result.Result;
import cool.mtc.core.result.ResultConstant;
import cool.mtc.core.util.CollectionUtil;
import cool.mtc.core.util.StringUtil;
import cool.mtc.web.WebProperties;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.MethodParameter;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.validation.BindException;
import org.springframework.validation.BindingResult;
import org.springframework.validation.FieldError;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.MissingServletRequestParameterException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Controller advice that wraps response bodies in a {@link Result} and converts
 * exceptions raised by handlers into {@link Result} responses.
 *
 * <p>Wrapping is limited to handlers whose declaring class lives in one of the
 * packages returned by {@link WebProperties#getGlobalResultWrapperBasePackages()};
 * when no base package is configured, every response is wrapped.
 *
 * @author 明河
 */
@Slf4j
@RestControllerAdvice
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class GlobalResultWrapper implements ResponseBodyAdvice<Object> {
    // Per-package decision caches. They are written from request-handling threads,
    // so they must be concurrent sets rather than plain HashSets.
    private static final Set<String> SUPPORT_PACKAGE_SET = ConcurrentHashMap.newKeySet();
    private static final Set<String> UN_SUPPORT_PACKAGE_SET = ConcurrentHashMap.newKeySet();

    private final WebProperties webProperties;
    private final ResultHandler resultHandler;

    /**
     * Decides whether the response produced by the given handler method should be wrapped.
     *
     * @param returnType    the return type of the handler method
     * @param converterType the selected message converter type (not consulted here)
     * @return {@code true} if no base package is configured, or if the package of the
     * handler's declaring class starts with one of the configured base packages
     */
    @Override
    public boolean supports(MethodParameter returnType, Class<? extends HttpMessageConverter<?>> converterType) {
        String[] basePackages = webProperties.getGlobalResultWrapperBasePackages();
        if (null == basePackages || basePackages.length == 0) {
            return true;
        }
        // Class#getPackage() may return null (e.g. for classes in the default package);
        // treat that as an empty package name instead of failing with an NPE.
        Package declaringPackage = returnType.getDeclaringClass().getPackage();
        String packageName = null == declaringPackage ? "" : declaringPackage.getName();
        if (SUPPORT_PACKAGE_SET.contains(packageName)) {
            return true;
        }
        if (UN_SUPPORT_PACKAGE_SET.contains(packageName)) {
            return false;
        }
        for (String basePackage : basePackages) {
            // NOTE(review): this is a plain string-prefix match, so "a.b" also matches "a.bc".
            if (packageName.startsWith(basePackage)) {
                SUPPORT_PACKAGE_SET.add(packageName);
                return true;
            }
        }
        UN_SUPPORT_PACKAGE_SET.add(packageName);
        return false;
    }

    /**
     * Forces the response content type to JSON and wraps the body in a {@link Result}
     * unless it already is one.
     *
     * <p>NOTE(review): the selected converter is not inspected; presumably handlers that
     * return a raw {@code String} would need dedicated handling — confirm against usage.
     *
     * @return {@code body} itself when it is a {@link Result}, otherwise {@code Result.ofData(body)}
     */
    @SneakyThrows
    @Override
    public Object beforeBodyWrite(Object body, MethodParameter returnType, MediaType selectedContentType, Class<? extends HttpMessageConverter<?>> selectedConverterType, ServerHttpRequest request, ServerHttpResponse response) {
        response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
        return body instanceof Result ? body : Result.ofData(body);
    }

    /**
     * Maps a {@link ParamException} to {@link ResultConstant#A0400}, using the exception message as detail.
     */
    @ExceptionHandler(value = ParamException.class)
    public Result<Object> handleException(ParamException ex) {
        return this.handleExceptionAndPrintLog(ex, ResultConstant.A0400, ex.getMessage());
    }

    /**
     * Returns the {@link Result} carried by the {@link CustomException} itself.
     */
    @ExceptionHandler(value = CustomException.class)
    public Result<Object> handleException(CustomException ex) {
        return this.handleExceptionAndPrintLog(ex, ex.getResult());
    }

    /**
     * Handler declared for {@link Exception}: maps to {@link ResultConstant#ERROR},
     * using {@code ex.toString()} as detail.
     */
    @ExceptionHandler(value = Exception.class)
    public Result<Object> handleException(Exception ex) {
        return this.handleExceptionAndPrintLog(ex, ResultConstant.ERROR, ex.toString());
    }

    /**
     * Maps a {@link BindException} to {@link ResultConstant#A0400} with the collected field errors as detail.
     */
    @ExceptionHandler(value = BindException.class)
    public Result<Object> handleException(BindException ex) {
        String detailErrMsg = this.getBindingResultErrMsg(ex.getBindingResult());
        return this.handleExceptionAndPrintLog(ex, ResultConstant.A0400, detailErrMsg);
    }

    /**
     * Maps a {@link MethodArgumentNotValidException} to {@link ResultConstant#A0400}
     * with the collected field errors as detail.
     */
    @ExceptionHandler(value = MethodArgumentNotValidException.class)
    public Result<Object> handleException(MethodArgumentNotValidException ex) {
        String detailErrMsg = this.getBindingResultErrMsg(ex.getBindingResult());
        return this.handleExceptionAndPrintLog(ex, ResultConstant.A0400, detailErrMsg);
    }

    /**
     * Maps an {@link HttpMessageNotReadableException} to {@link ResultConstant#A0400},
     * using the exception message as detail.
     */
    @ExceptionHandler(value = HttpMessageNotReadableException.class)
    public Result<Object> handleException(HttpMessageNotReadableException ex) {
        return this.handleExceptionAndPrintLog(ex, ResultConstant.A0400, ex.getMessage());
    }

    /**
     * Maps a {@link MissingServletRequestParameterException} to {@link ResultConstant#A0400},
     * naming the missing parameter in the detail message.
     */
    @ExceptionHandler(value = MissingServletRequestParameterException.class)
    public Result<Object> handleException(MissingServletRequestParameterException ex) {
        String detailErrMsg = String.format("缺少参数[%s]", ex.getParameterName());
        return this.handleExceptionAndPrintLog(ex, ResultConstant.A0400, detailErrMsg);
    }

    /**
     * Builds a single message out of all field errors of a binding result.
     *
     * <p>Each error is rendered as {@code [field]defaultMessage}; entries are separated by a
     * full-width comma and the message ends with a full-width period.
     *
     * @param bindingResult the binding result to read field errors from
     * @return the joined message, or {@link StringUtil#EMPTY} when there are no field errors
     */
    private String getBindingResultErrMsg(BindingResult bindingResult) {
        List<FieldError> errList = bindingResult.getFieldErrors();
        if (CollectionUtil.isEmpty(errList)) {
            return StringUtil.EMPTY;
        }
        StringBuilder errMsg = new StringBuilder();
        for (FieldError item : errList) {
            errMsg.append("[")
                    .append(item.getField())
                    .append("]")
                    .append(item.getDefaultMessage())
                    .append("，");
        }
        // Drop the trailing separator (a single char) before terminating the sentence.
        errMsg.deleteCharAt(errMsg.length() - 1);
        errMsg.append("。");
        return errMsg.toString();
    }

    /**
     * Passes the result through {@link ResultHandler#handleMsgArgs}, logs the exception and
     * returns the same result instance.
     */
    private Result<Object> handleExceptionAndPrintLog(Exception ex, Result<Object> result) {
        resultHandler.handleMsgArgs(result);
        this.printLog(ex, result);
        return result;
    }

    /**
     * Logs the exception together with a detailed result and chooses which result to return.
     *
     * <p>The detailed result is obtained from {@code result.newInstance().msg(detailMsg)} and is
     * always the one that gets logged. It is returned to the caller only when
     * {@link WebProperties#getShowExceptionDetail()} is enabled; otherwise the original
     * {@code result} is returned as passed in.
     *
     * @param ex        the exception being handled
     * @param result    the base result for this kind of error
     * @param detailMsg the detail message describing the concrete failure
     * @return the detailed result when exception details are enabled, otherwise {@code result}
     */
    private Result<Object> handleExceptionAndPrintLog(Exception ex, Result<Object> result, String detailMsg) {
        Result<Object> res = result.newInstance().msg(detailMsg);
        resultHandler.handleMsgArgs(res);

        // Print log
        this.printLog(ex, res);

        // Null-safe check: an unset flag must not blow up with an NPE while an
        // exception is already being handled; unset is treated as "hide details".
        if (Boolean.TRUE.equals(webProperties.getShowExceptionDetail())) {
            return res;
        }
        return result;
    }

    /**
     * Logs the exception at error level, prefixed with its class and the result's code and message.
     */
    private <T> void printLog(Exception ex, Result<T> result) {
        String logContent = String.format("%s（%s-%s）", ex.getClass().toString(), result.getCode(), result.getMsg());
        log.error(logContent, ex);
    }
}
